Skip to content

fix: jax reducers returning incorrect output values or lengths #3464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Apr 28, 2025

Conversation

ikrommyd
Copy link
Collaborator

@ikrommyd ikrommyd commented Apr 14, 2025

Needs more work and I'd appreciate any help @ianna @pfackeldey.
I'm adding a test that tests all the reducers via parametrization.

Definitely needs #3457

This PR should fix #3456, #3462, #3463 and #3465

@ikrommyd
Copy link
Collaborator Author

I'm currently skipping the argmin argmax tests because of #3463 but that should change.

@ikrommyd
Copy link
Collaborator Author

# See issue https://github.com/google/jax/issues/9296
result = jax.numpy.exp(
jax.ops.segment_sum(jax.numpy.log(array.data), parents.data)
)

So we can't take the product of an array with negative numbers? The logarithm will just NaN the output

@ikrommyd
Copy link
Collaborator Author

# See issue https://github.com/google/jax/issues/9296
result = jax.numpy.exp(
jax.ops.segment_sum(jax.numpy.log(array.data), parents.data)
)

So we can't take the product of an array with negative numbers? The logarithm will just NaN the output

I may have fixed that in 9874810. I still haven't made ak.any to behave like the cpu case when segments have zeros though.

@ianna
Copy link
Collaborator

ianna commented Apr 14, 2025

@ikrommyd - impressive! There are only 5 tests that fails: 3 for ak.any, 1 for ak.count, and 1 for ak.sum. It looks like the latter is failing with a boolean dtype. Perhaps, return should make an int?

Copy link
Collaborator

@ianna ianna left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking into it!

@ikrommyd
Copy link
Collaborator Author

ikrommyd commented Apr 14, 2025

So I made ci pass but there are a few things that are to be done for sure

  1. The code needs refactoring, I don't like how it looks at all, it's very hacky
  2. The reducer tests should definitely test more input array cases
  3. We need to make sure I'm not breaking something that should work but is untested due to the smaller amount of tests for the jax backend

Copy link
Collaborator

@ianna ianna left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ikrommyd - just some minor comments. I agree, JAX backend needs to be thoroughly tested. Perhaps, our fellow could take over? @pfackeldey - when do we discuss his project? Thanks.

@ikrommyd ikrommyd marked this pull request as ready for review April 27, 2025 06:01
@ikrommyd ikrommyd requested review from ianna and pfackeldey April 27, 2025 06:02
Copy link
Collaborator

@ianna ianna left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ikrommyd - Great! Thanks! The tests pass, I'm merging it now.

@ianna ianna merged commit 993c6f4 into main Apr 28, 2025
42 checks passed
@ianna ianna deleted the ikrommyd/jax-backend-reducers branch April 28, 2025 09:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
2 participants